# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from __future__ import annotations

import logging
import math
from functools import lru_cache
from typing import Optional

import torch
from torch import Tensor

from megatron.core.models.common.embeddings.rope_utils import get_pos_emb_on_this_cp_rank
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.transformer import TransformerConfig

logger = logging.getLogger(__name__)


class YarnRotaryEmbedding(RotaryEmbedding):
    """Yarn Rotary Embedding for language model.

    Args:
        kv_channels (int): Projection weights dimension in multi-head attention. Obtained from
            transformer config.
        rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
        rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings.
            Defaults to False.
        seq_len_interpolation_factor (float, optional): Scale for linearly interpolating RoPE to
            longer sequences. The value must be a float larger than 1.0. Defaults to None.
        rotary_base (float, optional): Base period for rotary position embeddings. Defaults to
            10000.
        use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly on
            the GPU. Defaults to False.
        scaling_factor (float, optional): Scaling factor for Yarn RoPE. Defaults to 1.0.
        original_max_position_embeddings (int, optional): Original maximum position embeddings
            length. Defaults to 4096.
        beta_fast (float, optional): Fast (high-rotation) boundary of the Yarn correction range.
            Defaults to 32.
        beta_slow (float, optional): Slow (low-rotation) boundary of the Yarn correction range.
            Defaults to 1.
        mscale (float, optional): Multiplier used in the numerator of the Yarn attention
            concentration factor. Defaults to 1.
        mscale_all_dim (float, optional): Multiplier used in the denominator of the Yarn attention
            concentration factor. Defaults to 0.
        correction_range_round_to_int (bool): Whether to round the correction range bounds to
            integers. Defaults to True.
        cp_group (torch.distributed.ProcessGroup, optional): Process group for context parallel.
            Defaults to None.
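
    Example:
        A minimal usage sketch (not taken from this file; it assumes a CUDA device is available
        and uses illustrative hyper-parameter values):

            rope = YarnRotaryEmbedding(
                kv_channels=128,
                scaling_factor=8.0,
                original_max_position_embeddings=4096,
            )
            # returns the rotary embedding and the Yarn concentration factor (mscale)
            emb, mscale = rope(max_seq_len=8192)  # emb: [8192, 1, 1, 128]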
    """

    def __init__(
        self,
        kv_channels: int,
        rotary_percent: float = 1.0,
        rotary_interleaved: bool = False,
        seq_len_interpolation_factor: Optional[float] = None,
        rotary_base: float = 10000.0,
        use_cpu_initialization: bool = False,
        scaling_factor: float = 1.0,
        original_max_position_embeddings: int = 4096,
        beta_fast: float = 32.0,
        beta_slow: float = 1.0,
        mscale: float = 1.0,
        mscale_all_dim: float = 0.0,
        correction_range_round_to_int: bool = True,
        cp_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        self.dim = kv_channels
        self.rotary_base = rotary_base
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        self.correction_range_round_to_int = correction_range_round_to_int

        device = 'cpu' if use_cpu_initialization else torch.cuda.current_device()

        with torch.device(device):
            self.inv_freq_extra = 1.0 / (
                self.rotary_base
                ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
            )
            self.inv_freq_inter = 1.0 / (
                self.scaling_factor
                * self.rotary_base
                ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
            )
            super().__init__(
                kv_channels=kv_channels,
                rotary_percent=rotary_percent,
                rotary_interleaved=rotary_interleaved,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
                rotary_base=rotary_base,
                use_cpu_initialization=use_cpu_initialization,
                cp_group=cp_group,
            )

            self._set_cos_sin_cache(
                self.original_max_position_embeddings, offset=0, dtype=torch.get_default_dtype()
            )

            # Clear the lru_cache of the forward method. If it is not cleared, the cached
            # forward entries (which include `self` in their cache key) cause a memory leak
            # in NeMo-RL.
            self.forward.cache_clear()

    @lru_cache(maxsize=32)
    def forward(
        self, max_seq_len: int, offset: int = 0, packed_seq: bool = False
    ) -> tuple[Tensor, float]:
        """Forward pass of Yarn Rotary Embedding.

        Args:
            max_seq_len (int): Maximum size of sequence
            offset (int, optional): RoPE offset. Defaults to 0.
            packed_seq (bool, optional): Whether to use packed sequence. Defaults to False.

        Returns:
            tuple[Tensor, float]: The rotary embedding tensor and the Yarn concentration factor
                (mscale) to be applied to its cosine and sine components.
        """
        assert (
            not self.rotary_interleaved
        ), "Yarn RoPE does not support interleaved rotary embeddings"

        if self.inv_freq_extra.device.type == 'cpu':
            # move `inv_freq_extra` to GPU once at the first micro-batch forward pass
            self.inv_freq_extra = self.inv_freq_extra.to(device=torch.cuda.current_device())

        if self.inv_freq_inter.device.type == 'cpu':
            # move `inv_freq_inter` to GPU once at the first micro-batch forward pass
            self.inv_freq_inter = self.inv_freq_inter.to(device=torch.cuda.current_device())

        low, high = _yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.dim,
            self.rotary_base,
            self.original_max_position_embeddings,
            self.correction_range_round_to_int,
        )
        inv_freq_mask = 1.0 - _yarn_linear_ramp_mask(
            low, high, self.dim // 2, device=self.inv_freq_extra.device
        ).to(dtype=torch.float32)
        inv_freq = self.inv_freq_inter * (1 - inv_freq_mask) + self.inv_freq_extra * inv_freq_mask

        seq = (
            torch.arange(
                max_seq_len, device=self.inv_freq_extra.device, dtype=self.inv_freq_extra.dtype
            )
            + offset
        )

        freqs = torch.outer(seq, inv_freq)

        _mscale = _yarn_get_concentration_factor(
            self.scaling_factor, self.mscale, self.mscale_all_dim
        )

        emb = torch.cat((freqs, freqs), dim=-1)
        # emb: [seq_length, 1, 1, dim]
        emb = emb[:, None, None, :]
        if self.cp_group is not None and self.cp_group.size() > 1 and not packed_seq:
            # slice rotary_pos_emb along sequence dimension
            # and select the partition of the current CP rank
            emb = get_pos_emb_on_this_cp_rank(emb, 0, self.cp_group)
        return emb, _mscale

    def _set_cos_sin_cache(self, seq_len, offset, dtype, packed_seq=False):
        """Compute and cache the cos and sin tensors, scaled by the Yarn concentration factor,
        for the given sequence length, offset, dtype and packed-sequence setting."""
        self.max_seq_len_cached = seq_len
        self.offset_cached = offset
        self.dtype_cached = dtype
        self.packed_seq_cached = packed_seq

        emb, _mscale = self.forward(seq_len, offset, packed_seq)
        self.register_buffer(
            "cos_cached", (emb.cos() * _mscale).to(dtype).contiguous(), persistent=False
        )
        self.register_buffer(
            "sin_cached", (emb.sin() * _mscale).to(dtype).contiguous(), persistent=False
        )

    def get_cached_cos_sin(
        self, seq_len, offset=0, dtype=torch.get_default_dtype(), packed_seq=False
    ):
        """Get cached cos and sin values."""
        if (
            seq_len > self.max_seq_len_cached
            or offset != self.offset_cached
            or dtype != self.dtype_cached
            or packed_seq != self.packed_seq_cached
        ):
            self._set_cos_sin_cache(seq_len, offset, dtype, packed_seq)
        return (self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...])


# Inverse dim formula: find the rotary pair index that completes `num_rotations` full rotations
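#
# A hedged sketch of the derivation (notation assumed here, not defined in this file): RoPE pair
# index i rotates with angular frequency theta_i = rotary_base ** (-2 * i / dim), so over
# `max_position_embeddings` positions it completes
#     num_rotations = max_position_embeddings * theta_i / (2 * pi)
# full turns. Solving for i gives the expression returned below:
#     i = dim * log(max_position_embeddings / (num_rotations * 2 * pi)) / (2 * log(rotary_base))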
def _yarn_find_correction_dim(
    num_rotations: float, dim: int, rotary_base: float = 10000, max_position_embeddings: int = 2048
) -> float:
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(rotary_base)
    )


# Find the [low, high] range of pair indices needing correction, based on rotation counts
def _yarn_find_correction_range(
    low_rot: float,
    high_rot: float,
    dim: int,
    rotary_base: float = 10000,
    max_position_embeddings: int = 2048,
    round_to_int: bool = True,
) -> tuple[int, int]:
    low = _yarn_find_correction_dim(low_rot, dim, rotary_base, max_position_embeddings)
    high = _yarn_find_correction_dim(high_rot, dim, rotary_base, max_position_embeddings)
    if round_to_int:
        low = math.floor(low)
        high = math.ceil(high)
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
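
# Worked example with illustrative values (not from this file): dim=128, rotary_base=10000,
# max_position_embeddings=4096, beta_fast=32, beta_slow=1 give bounds of roughly 20.9 and 45.0,
# so with round_to_int=True the returned range is (20, 46). In YarnRotaryEmbedding.forward, pair
# indices at or below the low bound keep their original frequencies, while indices at or above
# the high bound are fully interpolated (divided by scaling_factor).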


def _yarn_linear_ramp_mask(low: float, high: float, dim: int, device: torch.device) -> Tensor:
    if low == high:
        high += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func
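
# For example (illustrative arguments, with `device` a torch.device):
# _yarn_linear_ramp_mask(2, 5, 8, device) returns [0, 0, 0, 1/3, 2/3, 1, 1, 1], ramping linearly
# from 0 at or below the low bound to 1 at or above the high bound.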


def _yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
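
# For example, _yarn_get_mscale(8.0, 1.0) = 0.1 * 1.0 * ln(8) + 1, which is about 1.208
# (illustrative values); the factor grows logarithmically with `scale` and is exactly 1.0
# whenever scale <= 1.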


@lru_cache(maxsize=8)
def _yarn_get_concentration_factor(
    scaling_factor: float, mscale: Optional[float], mscale_all_dim: Optional[float]
) -> float:
    """
    Get the concentration factor (factor multiplied to the sine and cosine components of the
    embedding). This factor is also known as attention factor, and sometimes homonymously known as
    "mscale"
    """
    if mscale is None or mscale_all_dim is None:
        return _yarn_get_mscale(scaling_factor)
    return float(
        _yarn_get_mscale(scaling_factor, mscale) / _yarn_get_mscale(scaling_factor, mscale_all_dim)
    )
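
# A small illustrative check (values assumed): with scaling_factor=8.0, mscale=1.0 and
# mscale_all_dim=0.0, the factor is _yarn_get_mscale(8, 1) / _yarn_get_mscale(8, 0), i.e. about
# 1.208 / 1.0; whenever mscale == mscale_all_dim the ratio is exactly 1.0.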


def _yarn_get_concentration_factor_from_config(config: TransformerConfig) -> float:
    """Compute the Yarn concentration factor from a TransformerConfig, or return 1.0 if the
    config does not define a Yarn rotary scaling factor."""
    if hasattr(config, "yarn_rotary_scaling_factor"):
        return _yarn_get_concentration_factor(
            config.yarn_rotary_scaling_factor,
            getattr(config, "yarn_mscale", None),
            getattr(config, "yarn_mscale_all_dim", None),
        )
    return 1.0
